#! /usr/bin/python
# Copyright (C) 2015 Nicira, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import ast
import atexit
import json
import os
import random
import re
import shlex
import subprocess
import sys

import ovs.dirs
import ovs.util
import ovs.daemon
import ovs.vlog

from flask import Flask, jsonify
from flask import request, abort

# Flask application that serves the /Plugin.* and /NetworkDriver.* HTTP
# endpoints registered below with @app.route.
app = Flask(__name__)
# Logger handed to ovs.util.ovs_fatal() when startup fails.
vlog = ovs.vlog.Vlog("ovn-docker-overlay-driver")

# Name of the OVS bridge that ovn_init_overlay() creates if missing and that
# network_join() adds container veth ports to.
OVN_BRIDGE = "br-int"
# Value passed to ovn-nbctl as "--db=..."; empty until ovn_init_overlay()
# fills it in from "external_ids:ovn-nb" of the local Open_vSwitch record.
OVN_NB = ""
# Directory and spec file written by prepare(); cleanup() removes the file
# again at exit.
PLUGIN_DIR = "/etc/docker/plugins"
PLUGIN_FILE = "/etc/docker/plugins/openvswitch.spec"


def call_popen(cmd):
    """Run *cmd* and return its standard output, stripped of whitespace.

    cmd is an argument list handed straight to subprocess.Popen (no shell
    is involved).  The pipe is opened in text mode so that the result is a
    native string on both Python 2 and Python 3; callers in this file
    split it, strip quote characters from it and compare it against string
    literals.

    Returns "" when the command produced no output.
    Raises RuntimeError when the command exits with a non-zero status.
    """
    child = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                             universal_newlines=True)
    # communicate() always returns a (stdout, stderr) pair, so only the
    # first element needs to be inspected.
    stdout = child.communicate()[0]
    if child.returncode:
        # Wrap cmd in a 1-tuple so formatting also works if a tuple is passed.
        raise RuntimeError("Fatal error executing %s" % (cmd,))
    if stdout is None:
        return ""
    return stdout.strip()


def call_prog(prog, args_list):
    """Run *prog* with the common options followed by *args_list*.

    "--timeout=5" and "-vconsole:off" are placed directly after the program
    name.  The result of call_popen() is returned unchanged.
    """
    common_prefix = [prog, "--timeout=5", "-vconsole:off"]
    full_command = common_prefix + args_list
    return call_popen(full_command)


def ovs_vsctl(*args):
    """Run ovs-vsctl with *args* via call_prog() and return its result."""
    vsctl_args = list(args)
    return call_prog("ovs-vsctl", vsctl_args)


def ovn_nbctl(*args):
    """Run ovn-nbctl with *args* against the database named by OVN_NB.

    A "--db=<OVN_NB>" option is put in front of the caller's arguments
    before the command is handed to call_prog().
    """
    db_option = "--db=%s" % OVN_NB
    nbctl_args = [db_option] + list(args)
    return call_prog("ovn-nbctl", nbctl_args)


def cleanup():
    """Remove PLUGIN_FILE if it exists as a regular file."""
    if not os.path.isfile(PLUGIN_FILE):
        return
    os.remove(PLUGIN_FILE)


def ovn_init_overlay():
    """Prepare the local Open vSwitch instance for the overlay driver.

    Creates the OVN_BRIDGE bridge when "ovs-vsctl list-br" does not list
    it, loads the OVN_NB global from "external_ids:ovn-nb" of the
    Open_vSwitch record (exiting the process when it is empty), and records
    the bridge name under "external_ids:ovn-bridge".
    """
    global OVN_NB

    existing_bridges = ovs_vsctl("list-br").split()
    bridge_missing = OVN_BRIDGE not in existing_bridges
    if bridge_missing:
        bridge_id = "external_ids:bridge-id=" + OVN_BRIDGE
        ovs_vsctl("--", "--may-exist", "add-br", OVN_BRIDGE,
                  "--", "set", "bridge", OVN_BRIDGE, bridge_id,
                  "other-config:disable-in-band=true", "fail-mode=secure")

    # ovs-vsctl prints the value wrapped in double quotes; drop them.
    raw_nb = ovs_vsctl("get", "Open_vSwitch", ".", "external_ids:ovn-nb")
    OVN_NB = raw_nb.strip('"')
    if not OVN_NB:
        sys.exit("OVN central database's ip address not set")

    bridge_setting = "external_ids:ovn-bridge=" + OVN_BRIDGE
    ovs_vsctl("set", "open_vswitch", ".", bridge_setting)


def prepare():
    """Parse command-line options, initialise OVS/OVN and daemonize.

    Steps, in order:
      * register and handle the ovs.vlog and ovs.daemon command-line options;
      * run ovn_init_overlay();
      * create PLUGIN_DIR when it does not exist;
      * daemonize through ovs.daemon;
      * write "tcp://0.0.0.0:5000" to PLUGIN_FILE, calling
        ovs.util.ovs_fatal() if that fails;
      * register cleanup() to run at interpreter exit.
    """
    parser = argparse.ArgumentParser()

    ovs.vlog.add_args(parser)
    ovs.daemon.add_args(parser)
    args = parser.parse_args()
    ovs.vlog.handle_args(args)
    ovs.daemon.handle_args(args)
    ovn_init_overlay()

    if not os.path.isdir(PLUGIN_DIR):
        os.makedirs(PLUGIN_DIR)

    ovs.daemon.daemonize()
    try:
        # The context manager closes the file even when write() raises.
        with open(PLUGIN_FILE, "w") as spec_file:
            spec_file.write("tcp://0.0.0.0:5000")
    except Exception as e:
        ovs.util.ovs_fatal(0, "Failed to write to spec file (%s)" % str(e),
                           vlog)

    # Registered only after the spec file has been written successfully.
    atexit.register(cleanup)


@app.route('/Plugin.Activate', methods=['POST'])
def plugin_activate():
    """Answer Plugin.Activate by declaring the NetworkDriver interface."""
    implemented = ["NetworkDriver"]
    return jsonify({"Implements": implemented})


@app.route('/NetworkDriver.GetCapabilities', methods=['POST'])
def get_capability():
    """Answer GetCapabilities with a "global" scope."""
    capabilities = {"Scope": "global"}
    return jsonify(capabilities)


@app.route('/NetworkDriver.DiscoverNew', methods=['POST'])
def new_discovery():
    """Answer DiscoverNew with an empty JSON object; nothing else is done."""
    empty_reply = {}
    return jsonify(empty_reply)


@app.route('/NetworkDriver.DiscoverDelete', methods=['POST'])
def delete_discovery():
    """Answer DiscoverDelete with an empty JSON object; nothing else is done."""
    empty_reply = {}
    return jsonify(empty_reply)


@app.route('/NetworkDriver.CreateNetwork', methods=['POST'])
def create_network():
    """Create an OVN logical switch for the network in the request body.

    Aborts with HTTP 400 when the body is empty or has no "NetworkID".
    Otherwise replies with {} on success or {"Err": message} when the IPv4
    data is incomplete or ovn-nbctl fails.
    """
    if not request.data:
        abort(400)

    payload = json.loads(request.data)

    # The docker generated network uuid in NetworkID is used as the 'name'
    # of the OVN Logical_Switch record.
    network = payload.get("NetworkID", "")
    if not network:
        abort(400)

    # Only ipv4 subnets are handled until the ipv6 use case is clear.
    ipv4_data = payload.get("IPv4Data", "")
    if not ipv4_data:
        return jsonify({'Err': "create_network: No ipv4 subnet provided"})

    first_entry = ipv4_data[0]

    subnet = first_entry.get("Pool", "")
    if not subnet:
        return jsonify(
            {'Err': "create_network: no subnet in ipv4 data from libnetwork"})

    # The gateway is given as address/prefix; keep the address part only.
    gateway_with_mask = first_entry.get("Gateway", "")
    gateway_ip = gateway_with_mask.rsplit('/', 1)[0]
    if not gateway_ip:
        return jsonify(
            {'Err': "create_network: no gateway in ipv4 data from libnetwork"})

    subnet_setting = "external_ids:subnet=" + subnet
    gateway_setting = "external_ids:gateway_ip=" + gateway_ip
    try:
        ovn_nbctl("lswitch-add", network, "--", "set", "Logical_Switch",
                  network, subnet_setting, gateway_setting)
    except Exception as e:
        return jsonify({'Err': "create_network: lswitch-add %s" % (str(e))})

    return jsonify({})


@app.route('/NetworkDriver.DeleteNetwork', methods=['POST'])
def delete_network():
    """Delete the OVN logical switch named by "NetworkID".

    Aborts with HTTP 400 when the body is empty or has no "NetworkID".
    Replies with {} on success or {"Err": message} when ovn-nbctl fails.
    """
    if not request.data:
        abort(400)

    payload = json.loads(request.data)

    nid = payload.get("NetworkID", "")
    if not nid:
        abort(400)

    try:
        ovn_nbctl("lswitch-del", nid)
    except Exception as e:
        return jsonify({'Err': "delete_network: lswitch-del %s" % (str(e))})

    return jsonify({})


@app.route('/NetworkDriver.CreateEndpoint', methods=['POST'])
def create_endpoint():
    """Create an OVN logical port for the endpoint in the request body.

    Aborts with HTTP 400 when the body is empty or lacks "NetworkID" or
    "EndpointID".  The port is added to the switch named by NetworkID and
    given the requested IP address plus either the requested MAC address
    or a randomly generated one starting with "02:".

    Replies with an "Interface" object on success, or {"Err": message}
    when required interface data is missing or ovn-nbctl fails.
    """
    if not request.data:
        abort(400)

    payload = json.loads(request.data)

    nid = payload.get("NetworkID", "")
    if not nid:
        abort(400)

    eid = payload.get("EndpointID", "")
    if not eid:
        abort(400)

    interface = payload.get("Interface", "")
    if not interface:
        return jsonify({'Err': "create_endpoint: no interfaces structure "
                               "supplied by libnetwork"})

    ip_address_and_mask = interface.get("Address", "")
    if not ip_address_and_mask:
        return jsonify(
            {'Err': "create_endpoint: ip address not provided by libnetwork"})

    # The address arrives as address/prefix; keep the address part only.
    ip_address = ip_address_and_mask.rsplit('/', 1)[0]
    requested_mac = interface.get("MacAddress", "")

    try:
        ovn_nbctl("lport-add", nid, eid)
    except Exception as e:
        return jsonify({'Err': "create_endpoint: lport-add (%s)" % (str(e))})

    if requested_mac:
        mac_address = requested_mac
        # Nothing is echoed back when the caller chose the MAC address.
        returned_mac = ""
    else:
        octets = tuple(random.randint(0, 255) for _ in range(5))
        mac_address = "02:%02x:%02x:%02x:%02x:%02x" % octets
        # Only return a mac address if one did not come as request.
        returned_mac = mac_address

    try:
        ovn_nbctl("lport-set-addresses", eid,
                  " ".join((mac_address, ip_address)))
    except Exception as e:
        return jsonify(
            {'Err': "create_endpoint: lport-set-addresses (%s)" % (str(e))})

    interface_reply = {
        "Address": "",
        "AddressIPv6": "",
        "MacAddress": returned_mac,
    }
    return jsonify({"Interface": interface_reply})


def get_logical_port_addresses(eid):
    """Look up the addresses stored on logical port *eid*.

    Returns a (mac_address, ip_address, error) tuple.  On success error is
    None; when ovn-nbctl returns nothing or an empty address list, both
    addresses are None and error holds a message.  Only the first entry of
    the address list is used, and it must split into exactly two
    whitespace-separated fields.
    """
    raw = ovn_nbctl("--if-exists", "get", "Logical_port", eid, "addresses")
    if not raw:
        return (None, None, "endpoint not found in OVN database")

    # The column value is parsed as a Python literal.
    address_list = ast.literal_eval(raw)
    if len(address_list) < 1:
        return (None, None, "unexpected return while fetching addresses")

    first_entry = address_list[0]
    mac_address, ip_address = first_entry.split()
    return (mac_address, ip_address, None)


@app.route('/NetworkDriver.EndpointOperInfo', methods=['POST'])
def show_endpoint():
    """Report the addresses and outside veth name of an endpoint.

    Aborts with HTTP 400 when the body is empty or lacks "NetworkID" or
    "EndpointID".  Replies with a "Value" object holding ip_address,
    mac_address and veth_outside, or {"Err": message} when the addresses
    cannot be obtained from the OVN database.
    """
    if not request.data:
        abort(400)

    data = json.loads(request.data)

    nid = data.get("NetworkID", "")
    if not nid:
        abort(400)

    eid = data.get("EndpointID", "")
    if not eid:
        abort(400)

    try:
        (mac_address, ip_address, error) = get_logical_port_addresses(eid)
    except Exception as e:
        error = "show_endpoint: get Logical_port addresses. (%s)" % (str(e))
        return jsonify({'Err': error})

    # The lookup reports "not found" through the error element instead of
    # raising; return it rather than replying with null addresses.
    if error:
        return jsonify({'Err': error})

    # The outside veth is named after the first 15 characters of the
    # endpoint id.
    veth_outside = eid[0:15]
    return jsonify({"Value": {"ip_address": ip_address,
                              "mac_address": mac_address,
                              "veth_outside": veth_outside
                              }})


@app.route('/NetworkDriver.DeleteEndpoint', methods=['POST'])
def delete_endpoint():
    """Delete the OVN logical port named by "EndpointID".

    Aborts with HTTP 400 when the body is empty or lacks "NetworkID" or
    "EndpointID".  Replies with {} on success or {"Err": message} when
    ovn-nbctl fails.
    """
    if not request.data:
        abort(400)

    payload = json.loads(request.data)

    # NetworkID is only checked for presence; the deletion uses EndpointID.
    nid = payload.get("NetworkID", "")
    if not nid:
        abort(400)

    eid = payload.get("EndpointID", "")
    if not eid:
        abort(400)

    try:
        ovn_nbctl("lport-del", eid)
    except Exception as e:
        return jsonify({'Err': "delete_endpoint: lport-del %s" % (str(e))})

    return jsonify({})


@app.route('/NetworkDriver.Join', methods=['POST'])
def network_join():
    """Create the veth pair for an endpoint and attach it to OVN_BRIDGE.

    Aborts with HTTP 400 when the body is empty or lacks "NetworkID",
    "EndpointID" or "SandboxKey".  Otherwise:
      * looks up the endpoint's MAC address in the OVN database;
      * creates a veth pair named after the endpoint id, sets the MAC
        address on the inside end and brings the outside end up;
      * adds the outside end to OVN_BRIDGE and sets its external_ids.

    Replies with the inside interface name on success, or {"Err": message}
    as soon as any step fails.
    """
    if not request.data:
        abort(400)

    data = json.loads(request.data)

    nid = data.get("NetworkID", "")
    if not nid:
        abort(400)

    eid = data.get("EndpointID", "")
    if not eid:
        abort(400)

    sboxkey = data.get("SandboxKey", "")
    if not sboxkey:
        abort(400)

    # sboxkey is of the form: /var/run/docker/netns/CONTAINER_ID
    vm_id = sboxkey.rsplit('/')[-1]

    try:
        (mac_address, ip_address, error) = get_logical_port_addresses(eid)
    except Exception as e:
        error = "network_join: %s" % (str(e))
        return jsonify({'Err': error})

    # The lookup reports "not found" through the error element instead of
    # raising; stop here rather than configuring a veth with no MAC address.
    if error:
        return jsonify({'Err': error})

    veth_outside = eid[0:15]
    veth_inside = eid[0:13] + "_c"

    # Each command is built as an argument list so that request-derived
    # names are always passed to "ip" as single arguments.
    link_steps = (
        (["ip", "link", "add", veth_inside, "type", "veth",
          "peer", "name", veth_outside],
         "network_join: failed to create veth pair (%s)"),
        (["ip", "link", "set", "dev", veth_inside, "address", mac_address],
         "network_join: failed to set veth mac address (%s)"),
        (["ip", "link", "set", veth_outside, "up"],
         "network_join: failed to up the veth interface (%s)"),
    )
    for argv, error_format in link_steps:
        try:
            call_popen(argv)
        except Exception as e:
            return jsonify({'Err': error_format % (str(e))})

    try:
        ovs_vsctl("add-port", OVN_BRIDGE, veth_outside)
        ovs_vsctl("set", "interface", veth_outside,
                  "external_ids:attached-mac=" + mac_address,
                  "external_ids:iface-id=" + eid,
                  "external_ids:vm-id=" + vm_id,
                  "external_ids:iface-status=active")
    except Exception as e:
        error = "network_join: failed to create a port (%s)" % (str(e))
        return jsonify({'Err': error})

    return jsonify({"InterfaceName": {
                                        "SrcName": veth_inside,
                                        "DstPrefix": "eth"
                                     },
                    "Gateway": "",
                    "GatewayIPv6": ""})


@app.route('/NetworkDriver.Leave', methods=['POST'])
def network_leave():
    """Delete an endpoint's veth pair and remove its port from OVS.

    Aborts with HTTP 400 when the body is empty or lacks "NetworkID" or
    "EndpointID".  Replies with {} on success or {"Err": message} when
    deleting the link or the OVS port fails.
    """
    if not request.data:
        abort(400)

    data = json.loads(request.data)

    nid = data.get("NetworkID", "")
    if not nid:
        abort(400)

    eid = data.get("EndpointID", "")
    if not eid:
        abort(400)

    # The outside veth is named after the first 15 characters of the
    # endpoint id.
    veth_outside = eid[0:15]
    try:
        # Pass an argument list so that the request-derived name always
        # reaches "ip" as a single argument.
        call_popen(["ip", "link", "delete", veth_outside])
    except Exception as e:
        error = "network_leave: failed to delete veth pair (%s)" % (str(e))
        return jsonify({'Err': error})

    try:
        ovs_vsctl("--if-exists", "del-port", veth_outside)
    except Exception as e:
        error = "network_leave: failed to delete port (%s)" % (str(e))
        return jsonify({'Err': error})

    return jsonify({})

if __name__ == '__main__':
    # Parse options, initialise OVS/OVN, daemonize and write the plugin spec
    # file before any request is served.
    prepare()
    # Listen on all interfaces.  No port is passed here, so Flask's default
    # applies; the spec file written by prepare() advertises port 5000.
    app.run(host='0.0.0.0')
